WGAN-GP

Background

  WGAN-GP (Wasserstein Generative Adversarial Network with Gradient Penalty) was published at NIPS 2017 as an upgraded version of WGAN. WGAN's theory rests on the 1-Lipschitz condition, which WGAN enforces by weight clipping; this is not a very good approach. WGAN-GP replaces weight clipping with a gradient penalty (GP), giving a generative adversarial network that is more stable, converges faster, and produces higher-quality samples.
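
  For reference, the critic loss with the gradient penalty proposed in the WGAN-GP paper is

  $$L = \mathbb{E}_{\tilde{x}\sim P_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x\sim P_r}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}\sim P_{\hat{x}}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]$$

  where $\hat{x}$ is sampled uniformly along straight lines between pairs of real and generated samples, and the paper uses $\lambda = 10$ (the same weight appears as loss_weights=[1, 1, 10] in the code below).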

[Figure: wgan-gp]

Drawbacks of weight clipping

  It severely limits the network's expressive power: restricting the range of the weights makes it hard for the critic to approximate the function we actually want.
  It easily causes exploding or vanishing gradients, and since GAN training is already very sensitive to parameter settings, this makes convergence even harder (for contrast, a sketch of what weight clipping looks like in code follows this list).
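
  WGAN-style weight clipping amounts to something like the following after each critic update (a minimal sketch; clip_value and critic are illustrative names, not taken from the code below):

  clip_value = 0.01  # the clipping constant c from the WGAN paper
  for w in critic.trainable_weights:
      w.assign(tf.clip_by_value(w, -clip_value, clip_value))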

Features of WGAN-GP

  Keep the GAN architecture unchanged, but remove the final sigmoid of the discriminator.
  Use a penalty on randomly sampled (interpolated) points instead of WGAN's weight clipping (a sketch of this penalty term follows the list).
  Adam can be used again as the optimizer; the instability that Adam shows in WGAN does not occur here.
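
  As a rough sketch of the penalty term on interpolated samples (independent of the Keras-functional implementation below; critic, real and fake are placeholder names):

  def gradient_penalty(critic, real, fake):
      # Interpolate uniformly between real and generated samples.
      alpha = tf.random.uniform([tf.shape(real)[0], 1, 1, 1], 0., 1.)
      interpolated = alpha * real + (1. - alpha) * fake
      with tf.GradientTape() as tape:
          tape.watch(interpolated)
          scores = critic(interpolated)
      grads = tape.gradient(scores, interpolated)
      # Per-sample L2 norm over all pixel dimensions.
      norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
      return tf.reduce_mean((norms - 1.) ** 2)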

WGAN-GP architecture diagrams

[Figure: generator architecture]
[Figure: discriminator architecture]

TensorFlow 2.0 Implementation

import os
import numpy as np
import cv2 as cv
from functools import reduce, partial
import tensorflow as tf
import tensorflow.keras as keras


def compose(*funcs):
    # Compose a sequence of callables left to right: compose(f, g)(x) == g(f(x)).
    if funcs:
        return reduce(lambda f, g: lambda *a, **kw: g(f(*a, **kw)), funcs)
    else:
        raise ValueError('Composition of empty sequence not supported.')


def generator(input_shape):
    input_tensor = keras.layers.Input(input_shape, name='input')
    x = input_tensor

    x = compose(keras.layers.Dense(256, activation='relu', name='dense_relu1'),
                keras.layers.BatchNormalization(momentum=0.8),
                keras.layers.Dense(512, activation='relu', name='dense_relu2'),
                keras.layers.BatchNormalization(momentum=0.8),
                keras.layers.Dense(1024, activation='relu', name='dense_relu3'),
                keras.layers.BatchNormalization(momentum=0.8),
                keras.layers.Dense(784, activation='tanh', name='dense_tanh'),
                keras.layers.Reshape((28, 28, 1)))(x)

    model = keras.Model(input_tensor, x, name='WGANGP-Generator')

    return model


def discriminator(input_shape):
    input_tensor = keras.layers.Input(input_shape, name='input')
    x = input_tensor

    x = compose(keras.layers.Flatten(name='flatten'),
                keras.layers.Dense(512, activation='relu', name='dense_relu1'),
                keras.layers.Dense(256, activation='relu', name='dense_relu2'),
                keras.layers.Dense(1, name='dense'))(x)

    model = keras.Model(input_tensor, x, name='WGANGP-Discriminator')

    return model


class Weight_Addition(keras.layers.Layer):
    # Randomly interpolates between real and fake images; the gradient penalty is
    # evaluated at these interpolated points.
    def __init__(self, name):
        super(Weight_Addition, self).__init__()
        self._name = name

    def call(self, inputs, **kwargs):
        # Draw a fresh per-sample coefficient on every call; tf.random (instead of np.random)
        # keeps the sampling inside the graph, and tf.shape avoids relying on a global batch_size.
        r = tf.random.uniform((tf.shape(inputs[0])[0], 1, 1, 1), 0., 1.)
        return inputs[0] * r + inputs[1] * (1 - r)


def wgan_gp(input_image_shape, input_noise_shape, model_g, model_d):
    input_image = keras.layers.Input(input_image_shape, name='input_image')
    input_noise_d = keras.layers.Input(input_noise_shape, name='input_noise_discriminator')
    input_noise_g = keras.layers.Input(input_noise_shape, name='input_noise_generator')

    # Critic (discriminator) training model: the generator is frozen here.
    model_g.trainable = False
    fake_image = model_g(input_noise_d)
    real_conf = model_d(input_image)
    fake_conf = model_d(fake_image)
    random_sample = Weight_Addition(name='random_sample')([input_image, fake_image])
    sample_conf = model_d(random_sample)
    # Bind the interpolated samples into the gradient-penalty loss.
    partial_gp_loss = partial(gp_loss, random_sample=random_sample)
    partial_gp_loss.__name__ = 'gp_loss'
    model_discriminator = keras.Model([input_image, input_noise_d], [real_conf, fake_conf, sample_conf], name='WGAN-GP-discriminator')
    model_discriminator.compile(optimizer=optimizer_d, loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss], loss_weights=[1, 1, 10])

    # Generator training model: the critic is frozen here.
    model_g.trainable = True
    model_d.trainable = False

    image = model_g(input_noise_g)
    conf = model_d(image)

    model_generator = keras.Model(input_noise_g, conf, name='WGAN-GP-generator')
    model_generator.compile(optimizer=optimizer_g, loss=wasserstein_loss)

    return model_generator, model_discriminator


def save_picture(image, save_path, picture_num):
    # Map images from [-1, 1] back to [0, 255] and tile them into a picture_num x picture_num grid.
    image = ((image + 1) * 127.5).astype(np.uint8)
    image = np.concatenate([image[i * picture_num:(i + 1) * picture_num] for i in range(picture_num)], axis=2)
    image = np.concatenate([image[i] for i in range(picture_num)], axis=0)
    cv.imwrite(save_path, image)


def gp_loss(y_true, y_pred, random_sample):
    # Gradient penalty: push the critic's gradient norm at the interpolated samples towards 1.
    gradients = tf.gradients(y_pred, random_sample)[0]
    # Compute the L2 norm per sample (over all pixel dimensions), not over the whole batch.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
    return tf.reduce_mean((slopes - 1.) ** 2)


def wasserstein_loss(y_true, y_pred):
    # y_true is +1 for real and -1 for fake, so this is the (negated) Wasserstein estimate.
    return -tf.reduce_mean(y_true * y_pred)


if __name__ == '__main__':
    (x, _), (_, _) = keras.datasets.mnist.load_data()
    batch_size = 256
    epochs = 50
    tf.random.set_seed(22)
    save_path = r'.\wgangp'
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Scale images to [-1, 1] to match the generator's tanh output.
    x = x[..., np.newaxis].astype(np.float32) / 127.5 - 1
    x = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)

    optimizer_g = keras.optimizers.Adam(0.00005, 0.5)
    optimizer_d = keras.optimizers.Adam(0.0002, 0.5)

    real_dmean = keras.metrics.Mean()
    fake_dmean = keras.metrics.Mean()
    gmean = keras.metrics.Mean()

    model_d = discriminator(input_shape=(28, 28, 1))

    model_g = generator(input_shape=(100,))

    model_g.build(input_shape=(100,))
    model_g.summary()
    keras.utils.plot_model(model_g, 'WGANGP-generator.png', show_shapes=True, show_layer_names=True)

    model_d.build(input_shape=(28, 28, 1))
    model_d.summary()
    keras.utils.plot_model(model_d, 'WGANGP-discriminator.png', show_shapes=True, show_layer_names=True)

    model_generator, model_discriminator = wgan_gp(input_image_shape=(28, 28, 1), input_noise_shape=(100,), model_g=model_g, model_d=model_d)

    model_generator.build(input_shape=(100,))
    model_generator.summary()
    keras.utils.plot_model(model_generator, 'WGAN-GP-generator.png', show_shapes=True, show_layer_names=True)

    model_discriminator.build(input_shape=[(28, 28, 1), (100,)])
    model_discriminator.summary()
    keras.utils.plot_model(model_discriminator, 'WGAN-GP-discriminator.png', show_shapes=True, show_layer_names=True)

    for epoch in range(epochs):
        # Reshuffle the dataset each epoch (a fixed buffer size avoids a possibly zero-sized buffer).
        x = x.shuffle(buffer_size=10000)
        x_db = iter(x)

        for step, real_image in enumerate(x_db):
            # Skip the last incomplete batch so all shapes stay consistent.
            if real_image.shape[0] != batch_size:
                continue
            noise = np.random.normal(0, 1, (real_image.shape[0], 100)).astype(np.float32)

            # Track the critic's mean scores on real, fake and generated samples.
            real_conf, fake_conf, _ = model_discriminator([real_image, noise])
            real_dmean(real_conf)
            fake_dmean(fake_conf)
            gmean(model_generator(noise))

            # Critic targets: +1 for real, -1 for fake, dummy zeros for the gradient-penalty output.
            d_loss = model_discriminator.train_on_batch([real_image, noise], [np.ones((real_image.shape[0], 1)), -np.ones((real_image.shape[0], 1)), np.zeros((real_image.shape[0], 1))])
            g_loss = model_generator.train_on_batch(noise, np.ones((real_image.shape[0], 1)))

            if step % 20 == 0:
                print('epoch = {}, step = {}, real_dmean = {}, fake_dmean = {}, gmean = {}'.format(epoch, step, real_dmean.result(), fake_dmean.result(), gmean.result()))
                real_dmean.reset_states()
                fake_dmean.reset_states()
                gmean.reset_states()
                # Save a 10x10 grid of generated digits.
                fake_data = np.random.normal(0, 1, (100, 100)).astype(np.float32)
                fake_image = model_g(fake_data)
                save_picture(fake_image.numpy(), save_path + '\\epoch{}_step{}.jpg'.format(epoch, step), 10)

[Figure: wgan-gp]

Model results

[Figure: wgan-gp]

Tips

  1. Input images can first be normalized to [0, 1] or [-1, 1]; since network weights are generally small, normalized inputs simplify the computation and speed up convergence.
  2. Pay attention to the dimension transformations and the common NumPy / TensorFlow operations used here, otherwise the code may be difficult to follow.
  3. You can also add weight checkpointing, learning-rate decay, and early stopping.
  4. WGAN-GP is very sensitive to the network structure, the optimizer parameters, and the layers' hyperparameters; when the results are poor the cause is hard to pin down, which may require a fair amount of engineering experience.
  5. Create the discriminator first and compile it, so that discriminator model is fixed. Then, when building the combined generator model, the discriminator must not be trained, so set its trainable attribute to False; this does not affect the previously compiled discriminator model. You can check this through the model's _collection_collected_trainable_weights attribute: if it is empty the model will not be trained, otherwise it will. After compile this attribute is fixed, and however trainable is changed afterwards, training is unaffected as long as you do not compile again (a small sketch of this pattern follows the list).
  6. The code uses partial functions from functools; for more on partial see my other post, Function (函数).
  7. The WGAN-GP in this post is modified from the plain GAN; you can also try it on DCGAN, CGAN, and other models.
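
  As a small illustration of point 5 (a minimal sketch with toy layer sizes and standalone names such as d, g and combined, not part of the training script above):

  import tensorflow.keras as keras

  # Critic compiled while trainable is True: d.train_on_batch(...) updates its weights.
  d = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
  d.compile(optimizer='adam', loss='mse')

  # Freeze d before wiring it into the combined model; the already compiled d is unaffected.
  d.trainable = False
  g = keras.layers.Dense(4)  # stands in for the generator
  z = keras.layers.Input((2,))
  combined = keras.Model(z, d(g(z)))
  combined.compile(optimizer='adam', loss='mse')
  # combined.train_on_batch(...) now updates only g's weights, not d's,
  # because d's trainable flag was False when combined was compiled.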

WGAN-GP Summary

  Since WGAN-GP barely modifies the network structure, its network parameters are exactly the same as the plain GAN's. When WGAN-GP was proposed, its authors analyzed the loss function in depth and replaced weight clipping with penalties on randomly sampled (interpolated) points, which improves the network's expressive power and stability. This gradient-penalty idea strongly influenced the later development of generative adversarial networks; you can skip studying WGAN, but WGAN-GP is a model worth understanding.
